{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Neural network text prediction with `textgenrnn`\n",
    "\n",
    "By [Allison Parrish](http://www.decontextualize.com/)\n",
    "\n",
    "This is a quick tutorial on how to use [Max Woolf](https://github.com/minimaxir)'s [textgenrnn](https://github.com/minimaxir/textgenrnn) to get started generating text with recurrent neural networks.\n",
    "\n",
    "Like a [Markov chain](ngrams-and-markov-chains.ipynb), a recurrent neural network (RNN) is a way to make predictions about what will come next in a sequence. For our purposes, the sequence in question is a sequence of characters, and the prediction we want to make is *which character will come next*. Both Markov models and recurrent neural networks do this by using statistical properties of text to make a *probability distribution* for what character will come next, given some information about what comes before. The two procedures work very differently internally, and we're not going to go into the gory details about implementation here. (But if you're interested in the gory details, [here's a good place to start](https://karpathy.github.io/2015/05/21/rnn-effectiveness/).) For our purposes, the main *functional* difference between a Markov chain and a recurrent neural network is the *portion* of the sequence used to make the prediction. A Markov model uses a fixed window of history from the sequence, while an RNN (theoretically) uses the *entire history* of the sequence.\n",
    "\n",
    "## Start with Markov\n",
    "\n",
    "To illustrate, let's look at the word \"condescendences.\" In a Markov model based on bigrams from this string of characters, you'd make a list of bigrams and the characters that follow those bigrams, like so:\n",
    "\n",
    "| n-grams |\tnext? |\n",
    "| ------- | ----- |\n",
    "| co      | n |\n",
    "| on      | d |\n",
    "| nd      | e, e |\n",
    "| de      | s, n |\n",
    "| es      | c, (end of text) |\n",
    "| sc      | e |\n",
    "| ce      | n, s |\n",
    "| en      | d, c |\n",
    "| nc      | e |\n",
    "\n",
    "You could also write this as a probability distribution, with one column for each bigram. The value in each cell indicates the probability that the character following the bigram in a given row will be followed by the character in a given column:\n",
    "\n",
    "| n-grams | c | o | n | d | e | s | END |\n",
    "| ------- | - | - | - | - | - | - | --- |\n",
    "| co      | 0 | 0 | 1.0 | 0 | 0 | 0 | 0 | \n",
    "| on      | 0 | 0 | 0 | 1.0 | 0 | 0 | 0 | \n",
    "| nd      | 0 | 0 | 0 | 0 | 1.0 | 0 | 0 | \n",
    "| de      | 0 | 0 | 0.5 | 0 | 0 | 0.5 | 0 |\n",
    "| es      | 0.5 | 0 | 0 | 0 | 0 | 0 | 0.5 |\n",
    "| sc      | 0 | 0 | 0 | 0 | 1.0 | 0 | 0 |\n",
    "| ce      | 0 | 0 | 0.5 | 0 | 0 | 0.5 | 0 |\n",
    "| en      | 0.5 | 0 | 0 | 0.5 | 0 | 0 | 0 |\n",
    "| nc      | 0 | 0 | 0 | 0 | 1.0 | 0 | 0 |\n",
    "\n",
    "Each row of this table is a *probability distribution*, meaning that it shows how probable a given letter is to follow the n-gram in the original text. In a probability distribution, all of the values add up to 1.\n",
    "\n",
    "Fitting a Markov model to the data is a matter of looking at each sequence of characters in a given text, and updating the table of probability distributions accordingly. To make a prediction from this table, you can \"sample\" from the probability distribution for a given n-gram (i.e., sampling from the distribution for the bigram `de`, you'd have a 50% chance of picking `n` and a 50% chance of picking `s`).\n",
    "\n",
    "Another way of thinking about this Markov model is as a (hypothetical!) function *f* that takes a bigram as a parameter and returns a probability distribution for that bigram:\n",
    "\n",
    "    f(\"ce\") → [0.0, 0.0, 0.5, 0.0, 0.0, 0.5, 0.0]\n",
    "    \n",
    "(Note that the values at each index in this distribution line up with the columns in the table above.)\n",
    "    \n",
    "The items in the list returned from this function correspond to the probability for the corresponding next character, as given in the table. To sample from this list, you'd pick randomly among the indices according to their probabilities, and then look up the corresponding character by its position in the table.\n",
    "\n",
    "To generate new text from this model:\n",
    "\n",
    "1. Set your output string to a randomly selected n-gram\n",
    "2. Sample a letter from the probability distribution associated with the n-gram at the end of the output string\n",
    "3. Append the sampled letter to the end of the string\n",
    "4. Repeat from (2) until the END token is reached\n",
    "\n",
    "Of course, you don't write this function by hand! When you're creating a Markov model from your data (or \"training\" the model), you're essentially asking the computer to write this function *for you*. In this sense, a Markov model is a very simple kind of machine learning, since the computer \"learns\" the probability distribution from the data that you feed it.\n",
    "\n",
    "## A (very) simplified explanation of RNNs\n",
    "\n",
    "The mechanism by which a recurrent neural network \"learns\" probability distributions from sequences is much more sophisticated than the mechanism used in a Markov model, but functionally they're very similar: you give the computer some data to \"train\" on, and then ask it to automatically create a function that will return a probability distribution of what comes next, given some input. An RNN differs from a Markov chain in that to predict the next item in the sequence, you pass in *the entire sequence* instead of just the most recent n-gram.\n",
    "\n",
    "In other words, you can (again, hypothetically) think of an RNN as a way of automatically creating a function *f* that takes a sequence of characters of arbitrary length and returns a probability distribution for which character comes next in the sequence. Unlike a Markov chain, it's possible to *improve* the accuracy of the probability distribution returned from this function by training on the same data multiple times.\n",
    "\n",
    "Let's say that we want to train the RNN on the string \"condescendences\" to learn this function, and we want to make a prediction about what character is most likely to follow the sequence of characters \"condescendence\". When training a neural network, the process of learning a function like this works iteratively: you start off with a function that essentially gives a uniform probability distribution for each outcome (i.e., no one outcome is considered more likely than any other):\n",
    "\n",
    "    f(\"condescendence\") → [0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14] (after zero passes through the data)\n",
    "    \n",
    "... and as you iterate over the training data (in this case, the word \"condescendences\"), the probability distribution  gradually improves, ideally until it comes to accurately reflect the actual observed distribution (in the parlance, until it \"converges\"). After some number of passes through the data, you might expect the automatically-learned function to return distributions like this:\n",
    "\n",
    "    f(\"condescendence\") → [0.01, 0.02, 0.01, 0.03, 0.01, 0.9, 0.02] (after n passes through the data)\n",
    "\n",
    "A single pass through the training data is called an \"epoch.\" When it comes to any neural network, and RNNs in particular, more epochs is almost always better.\n",
    "\n",
    "To generate text from this model:\n",
    "\n",
    "1. Initialize your output string to an empty string, or a random character, or a starting \"prefix\" that you specify;\n",
    "2. Sample the next letter from the distribution returned for the current output string;\n",
    "3. Append that character to the end of the output string;\n",
    "4. Repeat from (2)\n",
    "\n",
    "Of course, in a real life application of both a Markov model and an RNN, you'd normally have more than seven items in the probability distribution! In fact, you'd have one element in the probability distribution for every possible character that occurs in the text. (Meaning that if there were 100 unique characters found in the text, the probability distribution would have 100 items in it.)\n",
    "\n",
    "## Markov chains vs RNNs    \n",
    "\n",
    "The primary benefit of an RNN over a Markov model for text generation is that an RNN takes into account *the entire history* of a sequence when generating the next character. This means that, for example, an RNN can theoretically learn how to close quotes and parentheses, which a Markov chain will never be able to reliably do (at least for pairs of quotes and parentheses longer than the n-gram of the Markov chain).\n",
    "\n",
    "The drawback of RNNs is that they are *computationally expensive*, from both a processing and memory perspective. This is (again) a simplification, but internally, RNNs work by \"squishing\" information about the training data down into large matrices, and make predictions by performing calculations on these large matrices. That means that you need a lot of CPU and RAM to train an RNN, and the resulting models (when stored to disk) can be very large. Training an RNN also (usually) takes a lot of time.\n",
    "\n",
    "Another consideration is the size of your corpus. Markov models will give interesting and useful results even for very small datasets, but RNNs require large amounts of data to train—the more data the better.\n",
    "\n",
    "So what do you do if you *don't* have a very large corpus? Or if you don't have a lot of time to train on your corpus?\n",
    "\n",
    "## RNN generation from pre-trained models\n",
    "\n",
    "Fortunately for us, developer and data scientist [Max Woolf](https://github.com/minimaxir) has made a Python library called [textgenrnn](https://github.com/minimaxir/textgenrnn) that makes it really easy to experiment with RNN text generation. This library includes a model (according to the documentation) \"trained on hundreds of thousands of text documents, from Reddit submissions (via BigQuery) and Facebook Pages (via my Facebook Page Post Scraper), from a very diverse variety of subreddits/Pages,\" and allows you to use this model as a starting point for your own training.\n",
    "\n",
    "To install textgenrnn, you'll probably want to install [Keras](http://keras.io/) first. With Anaconda:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!conda install -y keras"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then install textgenrnn with `pip`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install textgenrnn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once it's installed, import the `textgenrnn` class from the package:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from textgenrnn import textgenrnn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And create a new `textgenrnn` object like so:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "textgen = textgenrnn()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This object has a `.generate()` method which will, by default, generate text from the pre-trained model only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking to the stream and a start in the first time to see the strong of the success and a president and the star to the problems in the street to the performance is a statement of the most life in \n",
      "\n"
     ]
    }
   ],
   "source": [
    "textgen.generate()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To train it on your own text, use the `.train_on_texts()` method, passing in a list of strings. The `num_epochs` parameter allows you to indicate how many epochs (i.e., passes over the data) should be performed. The more epochs the better, especially for shorter texts, but you'll get okay results even with just a few. For *The Road Not Taken*, twenty epochs worked well for me:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/20\n",
      "754/754 [==============================] - 1s 2ms/step - loss: 0.9443\n",
      "Epoch 2/20\n",
      "754/754 [==============================] - 1s 2ms/step - loss: 0.8913\n",
      "Epoch 3/20\n",
      "754/754 [==============================] - 1s 2ms/step - loss: 0.8354\n",
      "Epoch 4/20\n",
      "754/754 [==============================] - 1s 2ms/step - loss: 0.7948\n",
      "Epoch 5/20\n",
      "754/754 [==============================] - 1s 2ms/step - loss: 0.7595\n",
      "Epoch 6/20\n",
      "754/754 [==============================] - 2s 2ms/step - loss: 0.7419\n",
      "Epoch 7/20\n",
      "754/754 [==============================] - 2s 2ms/step - loss: 0.7108\n",
      "Epoch 8/20\n",
      "754/754 [==============================] - 2s 2ms/step - loss: 0.6971\n",
      "Epoch 9/20\n",
      "754/754 [==============================] - 2s 2ms/step - loss: 0.6671\n",
      "Epoch 10/20\n",
      "754/754 [==============================] - 2s 2ms/step - loss: 0.6484\n",
      "Epoch 11/20\n",
      "754/754 [==============================] - 2s 2ms/step - loss: 0.6317\n",
      "Epoch 12/20\n",
      "754/754 [==============================] - 2s 2ms/step - loss: 0.6174\n",
      "Epoch 13/20\n",
      "754/754 [==============================] - 2s 2ms/step - loss: 0.6048\n",
      "Epoch 14/20\n",
      "754/754 [==============================] - 2s 3ms/step - loss: 0.5942\n",
      "Epoch 15/20\n",
      "754/754 [==============================] - 2s 2ms/step - loss: 0.5849\n",
      "Epoch 16/20\n",
      "754/754 [==============================] - 2s 2ms/step - loss: 0.5769\n",
      "Epoch 17/20\n",
      "754/754 [==============================] - 2s 2ms/step - loss: 0.5711\n",
      "Epoch 18/20\n",
      "754/754 [==============================] - 2s 2ms/step - loss: 0.5641\n",
      "Epoch 19/20\n",
      "754/754 [==============================] - 2s 2ms/step - loss: 0.5597\n",
      "Epoch 20/20\n",
      "754/754 [==============================] - 2s 2ms/step - loss: 0.5564\n"
     ]
    }
   ],
   "source": [
    "textgen.train_on_texts(open(\"frost.txt\").readlines(), num_epochs=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After training, you can generate new text using the `.generate()` method again:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "And both and as down that the first ages and be that the first that the first travelled one as far as I could ages and be about the wood,\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "textgen.generate()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The results aren't very interesting because by default the generator is very conservative in how it samples from the probability distribution. You can use the `temperature` parameter to make the sampling a bit more likely to pick improbable outcomes. The higher the value, the weirder the results. The default is 0.2, and going above 1.0 is likely to produce unacceptably strange results:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "In leaves sorry I could sood sorry I having the poor as I had to down the ages in the wive as the diverged in the one traveler, and I stood\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "textgen.generate(temperature=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tow that the one one at morning there in a yell travels to what to worre.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "textgen.generate(temperature=0.9)"
   ]
  },
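  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for what temperature does to a probability distribution, here's a rough sketch of one common way to apply it: divide the log of each probability by the temperature, then rescale so that everything adds up to 1 again. I'm not claiming this is exactly how `textgenrnn` does it internally, and `apply_temperature` is a name I made up for the example:\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def apply_temperature(dist, temperature):\n",
    "    # Divide each log-probability by the temperature, then renormalize.\n",
    "    # Outcomes with zero probability are skipped and stay at zero.\n",
    "    scaled = [math.exp(math.log(p) / temperature) if p > 0 else 0.0\n",
    "              for p in dist]\n",
    "    total = sum(scaled)\n",
    "    return [s / total for s in scaled]\n",
    "\n",
    "# The example distribution from the RNN section earlier in this notebook.\n",
    "dist = [0.01, 0.02, 0.01, 0.03, 0.01, 0.9, 0.02]\n",
    "for temperature in (0.2, 0.5, 1.0, 1.5):\n",
    "    adjusted = apply_temperature(dist, temperature)\n",
    "    print(temperature, [round(p, 3) for p in adjusted])\n",
    "```\n",
    "\n",
    "In this sketch, a temperature of 1.0 leaves the distribution as it was, values below 1.0 push even more of the probability onto the likeliest character, and values above 1.0 flatten the distribution so that unlikely characters get picked more often."
   ]
  },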
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you pass a number *n* to the `.generate()` method as its first parameter, and the parameter `return_as_list=True`, `.generate()` will return a list of *n* instances of text generation from the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "And as way that other to screase.\n",
      "Because other good on the feel about the first doad that stack,\n",
      "And having the undergrowthy and wanted sororing about the roads one way to one eneust the first and equally bectro equally leaves in as another ages and as fair,\n",
      "And ages in the brothes as mance as for the roads in the one least way tallmorth in for a black.\n",
      "I shall be some way leads both diverged for, but worn in the one less travelled there\n",
      "And be in the condence, and I stood\n",
      "To kent the passing there\n",
      "\n",
      "In and be ages travel both the grasside, and I haved be soober,\n",
      "Then diverged in an iqually least less llastread diverged in a I had a longe as fair,\n"
     ]
    }
   ],
   "source": [
    "poem = textgen.generate(10, temperature=0.8, return_as_list=True)\n",
    "for line in poem:\n",
    "    print(line.strip())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(This may take a little while.)\n",
    "\n",
    "I've found that `textgenrnn` works especially well with very short, word-length texts. For example, download [this file of human moods](https://github.com/dariusk/corpora/blob/master/data/humans/moods.json) from Corpora Project, and put it in the same directory as this notebook. The `textgenrnn` library stores its models globally, so you'll first need to reset the library to its initial state:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "textgen.reset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then load the JSON file and grab just the list of words naming moods:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import json\n",
    "mood_data = json.loads(open(\"./moods.json\").read())\n",
    "moods = mood_data['moods']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, train the RNN on these moods. One epoch will do the trick:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/1\n",
      "6651/6651 [==============================] - 13s 2ms/step - loss: 2.1288\n"
     ]
    }
   ],
   "source": [
    "textgen.train_on_texts(moods, num_epochs=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now generate a list of new moods:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "new_moods = textgen.generate(25, temperature=0.8, return_as_list=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And print them out!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "terered\n",
      "distructive\n",
      "formersaked form\n",
      "usedious\n",
      "worship\n",
      "lose\n",
      "day\n",
      "experientacled\n",
      "sircial\n",
      "influet\n",
      "indeelled\n",
      "buld\n",
      "phoneed\n",
      "ounded\n",
      "guy\n",
      "announced\n",
      "orchosies\n",
      "time\n",
      "boardy\n",
      "reless\n",
      "includive\n",
      "sometifent\n",
      "upcomist\n",
      "parablever\n",
      "suffered\n"
     ]
    }
   ],
   "source": [
    "print('\\n'.join(new_moods))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "I don't know about you, but I'm feeling a little sometifent today."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Further reading\n",
    "\n",
    "* [This notebook from the creator of textgenrnn](https://github.com/minimaxir/textgenrnn/blob/master/docs/textgenrnn-demo.ipynb) covers everything about the library that I covered in this tutorial—and much more, including how to start generation from a particular \"seed\" and how to save and load models (useful if you spent an afternoon training a model on your own corpus and don't want to have to do it again!)\n",
    "* Take a look at [Janelle Shane's wonderful overview of how she uses RNNs in her process](http://aiweirdness.com/faq). And then take a look at her [wonderful creative work with RNNs](http://aiweirdness.com/)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
